Bijectors#

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from rlxutils import subplots, copy_func
import tensorflow_probability as tfp
import tensorflow as tf
import ppdl 
import pandas as pd

tfd = tfp.distributions
tfb = tfp.bijectors

%matplotlib inline

Bijectors are invertible transformations#

When applied to TFP distributions, bijectors produce fully valid distributions.

TFP has a collection of bijectors that allow you to transform distributions while keeping the ability to sample and compute densities.

We will see them in more detail later in the course, but for now, let’s understand what they do.

Observe how we scale and shift a distribution.

# Two elementary bijectors: multiply by 0.5 and add 0.3.
bscale = tfb.Scale(.5)
bshift = tfb.Shift(.3)
d_orig = tfd.Beta(1.8,1.5)
# Calling a bijector on a distribution returns a new TransformedDistribution
# (see the printed `d_scaled` below), which can be transformed again.
d_scaled = bscale(d_orig)
d_scaled_and_shifted = bshift(d_scaled)

# One panel per distribution. `subplots` (rlxutils) yields (axis, index) pairs;
# presumably it also makes each axis current, since the plt.* calls below do
# not receive `ax` — TODO confirm.
for ax,i in subplots(range(3), usizex=5):
    if i==0: ppdl.plot_pdf(d_orig); plt.title("original")
    if i==1: ppdl.plot_pdf(d_scaled); plt.title("scaled")
    if i==2: ppdl.plot_pdf(d_scaled_and_shifted); plt.title("scaled and shifted")
    # Same x-range on every panel so the effect of each bijector is comparable.
    plt.xlim(-.1,1.1)
    plt.grid();
plt.tight_layout()
../_images/683819dbf48884c7f431dd69e9965ad50cd77da0e29d7c027a14bce793e258e1.png

The resulting distributions are fully valid TFP distribution objects.

d_scaled
<tfp.distributions.TransformedDistribution 'scaleBeta' batch_shape=[] event_shape=[] dtype=float32>
s = d_scaled.sample(10)
s
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.11156841, 0.39035392, 0.07765947, 0.19926141, 0.3500817 ,
       0.19258274, 0.16098109, 0.31810397, 0.10686599, 0.09793448],
      dtype=float32)>
d_scaled.log_prob(s)
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.5458902 , 0.91539323, 0.29789424, 0.8819361 , 0.9846997 ,
       0.86564505, 0.77117884, 1.0047415 , 0.5174571 , 0.4588679 ],
      dtype=float32)>

Chaining and inverting bijectors#

You can chain bijectors to create a new bijector. Observe that they are specified in the reverse order to which they are applied: `tfb.Chain([bshift, bscale])` first scales and then shifts.

You can also create an inverse bijector with `tfb.Invert`.

# Chain applies its bijectors right-to-left: bscale first, then bshift
# (e.g. 2.0 -> 1.0 -> 1.3, as printed further below).
bc     = tfb.Chain([bshift,bscale])
# Invert swaps forward and inverse, so bci undoes bc.
bci    = tfb.Invert(bc)

dt     = bc(d_orig)
# Pushing the transformed distribution through the inverse should give back
# the original density.
d_back = bci(dt)

# Same panel layout and x-range as the scale/shift demo for easy comparison.
for ax,i in subplots(3, usizex=5):
    if i==0: ppdl.plot_pdf(d_orig); plt.title("original distribution")
    if i==1: ppdl.plot_pdf(dt); plt.title("chain transformed")
    if i==2: ppdl.plot_pdf(d_back); plt.title("transformed back")
    plt.xlim(-.1,1.1)
    plt.grid();
../_images/319cf3992974f453603fd9708371729aba00cd0bec1db5659490662e6fec4a15.png

Bijectors are general transformations on TF objects#

x = tf.Variable(2.)
tx = bc(x)
tx
<tf.Tensor: shape=(), dtype=float32, numpy=1.3>
bc.inverse(tx)
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>

Modelling stuff#

Bijectors make it easier to model quantities with non-standard distributions, for instance:

# Chain applies its bijectors right-to-left: shift by +0.5, scale by 1.5,
# tanh, shift by 2 and finally scale by 3.
t = tfb.Chain([tfb.Scale(3), tfb.Shift(2), tfb.Tanh(), tfb.Scale(1.5), tfb.Shift(+.5)])
d1 = tfd.Beta(1.2,1.8)
# d2 is d1 pushed through the whole chain; it is itself a TFP distribution.
d2 = t(d1)

# Side-by-side densities of the base and the transformed distribution.
for ax,i in subplots(2, usizex=5):
    if i==0: ppdl.plot_pdf(d1); plt.title("original")
    if i==1: ppdl.plot_pdf(d2); plt.title("transformed")
    plt.grid();
plt.tight_layout()
../_images/063c72a81803e1003f1e7a57d5b958924a03022b56d3c08fdf96cbf1c6e01801.png

validate_args#

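Observe that sometimes a bijector might not be invertible for certain input values. For instance, `tfb.Square` is only a bijection on non-negative numbers, but samples from a standard Normal can be negative.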

This is only checked if `validate_args` is set to `True`; otherwise, `nan` values are generated and the code downstream will fail somewhere.

# tfb.Square requires non-negative inputs (see the validation error further
# below), but a standard Normal also produces negative values.
# Without validate_args nothing is checked here; the problem only shows up as
# nan densities when evaluating prob().
t = tfb.Square()
d = tfd.Normal(loc=0, scale=1)
dt = t(d)
s = dt.sample(10)
s
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.05243969, 0.00711866, 0.3022964 , 0.08465093, 1.0127188 ,
       1.7130485 , 0.5749669 , 0.41861057, 0.9270925 , 0.80052745],
      dtype=float32)>
dt.prob(s)
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.84852153, 2.3557818 ,        nan, 0.6571772 , 0.1194611 ,
       0.06471597, 0.19733593,        nan,        nan,        nan],
      dtype=float32)>

With `validate_args` set to `True`, an exception is raised and can be dealt with.

# Same setup as before, but the bijector validates its inputs, so sampling a
# negative Normal value raises instead of silently producing nan densities.
t = tfb.Square(validate_args=True)
d = tfd.Normal(loc=0, scale=1)
dt = t(d)
try:
    s = dt.sample(10)
except tf.errors.InvalidArgumentError as e:
    # Only the validation failure is expected here; any other error should
    # surface normally instead of being printed and swallowed.
    print(e)
All elements must be non-negative..  
Condition x >= 0 did not hold element-wise:
x (shape=(10,) dtype=float32) = 
['-0.9586899', '0.72535914', '1.2406634', '...']

However, the code is slower, even in a simple scenario. We time both cases with positive samples, which we know are valid.

t = tfb.Chain([tfb.Scale(2, validate_args=True), tfb.Square(validate_args=True)])
d = tfd.Normal(loc=100, scale=1)
dt = t(d)
%timeit s = dt.sample(1000)
%timeit dt.log_prob(s)
3.12 ms ± 34 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.94 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
t = tfb.Square()
t = tfb.Chain([tfb.Scale(2), tfb.Square()])
d = tfd.Normal(loc=100, scale=1)
dt = t(d)
%timeit s = dt.sample(1000)
%timeit dt.log_prob(s)
2.92 ms ± 52.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.89 ms ± 413 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Bijector parameters are learnable#

Observe how we can learn bijector and distribution parameters together.

# Ground-truth model: Beta(a, b) stretched by a factor of 3.
# Note: this rebinds d_orig from the earlier scale/shift demo.
scale = tfb.Scale(3)
a,b = 2, 1.5
dbeta = tfd.Beta(a,b)
d_orig = scale(dbeta)  

# Observed data; `optimize` below reads x, a and b as notebook globals.
x = d_orig.sample(100000)
ppdl.plot_pdf(d_orig, hist_args={'color': 'red', 'bins': 100, 'alpha': .5})
plt.axvline(np.mean(x), ls="--", color="black", alpha=.5, label="sample mean")
plt.grid(); plt.legend();
../_images/7e4801b095ebab162782a83d7f49e62b9845334d8788c400a89b21562e275fdf.png

Although it is trivial, here we will only learn the bijector parameter.

def optimize(init_sc=10., validate_args=False, n_epochs=200, learning_rate=0.1):
    """Fit the scale of ``tfb.Scale`` applied to ``Beta(a, b)`` by maximum likelihood.

    Reads the notebook globals ``a`` and ``b`` (Beta concentrations, kept
    fixed) and ``x`` (observed samples). Only the bijector scale is learned.

    Args:
        init_sc: initial value of the learnable scale.
        validate_args: forwarded to ``tfd.Beta``. When True, samples outside
            the Beta support raise instead of producing nan densities.
        n_epochs: number of Adam steps.
        learning_rate: Adam learning rate.

    Returns:
        ``(hloss, hparams, hgrads, sc)``:
        - ``hloss``: per-epoch negative log-likelihood.
        - ``hparams``: per-epoch scale value.
        - ``hgrads``: per-epoch gradients as a 2-D array of shape ``(epochs, 1)``.
        - ``sc``: the trained ``tf.Variable``.

        If a nan gradient stops training early, ``hloss`` holds one more entry
        (the failing epoch) than ``hparams`` and ``hgrads``.

    Raises:
        tf.errors.InvalidArgumentError: with ``validate_args=True``, when a
            back-transformed sample falls outside the Beta support.
    """
    sc = tf.Variable(init_sc, dtype=tf.float32)
    params = [sc]
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # The Beta base has no trainable parameters, so it is built once.
    base = tfd.Beta(a, b, validate_args=validate_args)

    hloss, hgrads, hparams = [], [], []
    # NOTE(review): `pbar` is not imported in the visible import cell; the
    # printed progress looks like progressbar2's `progressbar` — confirm it is defined.
    for _ in pbar(range(n_epochs)):
        with tf.GradientTape() as tape:
            # The bijector wraps the variable, so the tape tracks `sc`.
            d = tfb.Scale(sc)(base)
            negloglik = -tf.reduce_mean(d.log_prob(x))
        hloss.append(negloglik.numpy())

        gradients = tape.gradient(negloglik, params)
        grad_values = [g.numpy() for g in gradients]
        # Presumably nan arises when the scale maps x outside Beta's support
        # (log_prob is nan there); stop rather than push nan into `sc`.
        if np.any(np.isnan(grad_values)):
            print("nan gradients")
            break
        optimizer.apply_gradients(zip(gradients, params))
        hgrads.append(grad_values)
        hparams.append(sc.numpy())

    # reshape keeps the (epochs, n_params) layout even if training stopped at epoch 0.
    hgrads = np.asarray(hgrads, dtype=np.float32).reshape(-1, len(params))
    hloss = np.asarray(hloss)
    hparams = np.asarray(hparams)

    return hloss, hparams, hgrads, sc
    
def plot_optim(hloss, hparams, hgrads):
    """Plot an optimization history as three side-by-side panels.

    The panels show loss, parameter value and parameter gradient per epoch.

    Args:
        hloss: per-epoch loss values.
        hparams: per-epoch parameter values.
        hgrads: per-epoch gradients, typically shape ``(epochs, n_params)``.
            Only the first parameter's column is plotted. A 1-D array (e.g.
            an empty history) is plotted as is.
    """
    hgrads = np.asarray(hgrads)
    # An empty history can arrive as shape (0,), where [:, 0] would fail.
    grad_curve = hgrads[:, 0] if hgrads.ndim > 1 else hgrads
    for ax,i in subplots(3, usizex=5):
        if i==0: plt.plot(hloss); plt.title("loss")
        if i==1: plt.plot(hparams); plt.title("parameter value")
        if i==2: plt.plot(grad_curve); plt.title("parameter gradient")
        plt.xlabel("epoch")
        plt.grid();
    # Match the other multi-panel cells so labels don't overlap.
    plt.tight_layout()
hloss, hparams, hgrads, sc = optimize()
sc
100% (200 of 200) |######################| Elapsed Time: 0:00:03 Time:  0:00:03
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.1905434>
plot_optim(hloss, hparams, hgrads)
../_images/0dc63d913331918545dbae23f5e8591db2fb9a0d899316c0d4a9f71a7e2b00d1.png
# The `d` built inside `optimize` is local to that function; the global `d`
# here would be the unrelated Normal(loc=100) from the timing cells. Rebuild
# the fitted model from the learned scale `sc` instead.
d_fitted = tfb.Scale(sc)(tfd.Beta(a, b))
ppdl.plot_pdf(d_fitted, hist_args={'label': 'fitted distribution', 'bins': 100, 'alpha': .5})
plt.hist(x.numpy(), bins=100, density=True, alpha=.5, color="red", label="original distribution")
plt.grid(); plt.legend();
../_images/63a8a430cf28ffc5dccea9753a24e066c49c282805c4aeaaf1828d68787bf419.png

Observe that we might hit invalid values during optimization: with a negative scale, the back-transformed samples fall outside the Beta support, and the Beta distribution returns `nan` when computing densities for values outside its domain.

hloss, hparams, hgrads, sc = optimize(init_sc=-5)
 75% (150 of 200) |################      | Elapsed Time: 0:00:02 ETA:   0:00:00
nan gradients
plot_optim(hloss, hparams, hgrads)
../_images/7eacfaa76eeb9756a3bbb31eb05f55b023358db82793bea8107f83e43727e106.png

With `validate_args` set to `True`, TFP checks its inputs and raises an error that points to these situations.

hloss, hparams, hgrads, sc = optimize(init_sc=-5, validate_args=True)
  0% (0 of 200) |                        | Elapsed Time: 0:00:00 ETA:  --:--:--
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_53908/1940918081.py in <module>
----> 1 hloss, hparams, hgrads, sc = optimize(init_sc=-5, validate_args=True)

/tmp/ipykernel_53908/734260892.py in optimize(init_sc, validate_args)
     13             c = tfb.Scale(sc)
     14             d = c(tfd.Beta(a,b, validate_args=validate_args))
---> 15             negloglik    = -tf.reduce_mean(d.log_prob(x))
     16             hloss.append(negloglik.numpy())
     17 

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py in log_prob(self, value, name, **kwargs)
   1314         values of type `self.dtype`.
   1315     """
-> 1316     return self._call_log_prob(value, name, **kwargs)
   1317 
   1318   def _call_prob(self, value, name, **kwargs):

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
   1296     with self._name_and_control_scope(name, value, kwargs):
   1297       if hasattr(self, '_log_prob'):
-> 1298         return self._log_prob(value, **kwargs)
   1299       if hasattr(self, '_prob'):
   1300         return tf.math.log(self._prob(value, **kwargs))

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py in _log_prob(self, y, **kwargs)
    368         y, event_ndims=event_ndims, **bijector_kwargs)
    369     if self.bijector._is_injective:  # pylint: disable=protected-access
--> 370       base_log_prob = self.distribution.log_prob(x, **distribution_kwargs)
    371       return base_log_prob + tf.cast(ildj, base_log_prob.dtype)
    372 

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py in log_prob(self, value, name, **kwargs)
   1314         values of type `self.dtype`.
   1315     """
-> 1316     return self._call_log_prob(value, name, **kwargs)
   1317 
   1318   def _call_prob(self, value, name, **kwargs):

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
   1294         value, name='value', dtype_hint=self.dtype,
   1295         allow_packing=True)
-> 1296     with self._name_and_control_scope(name, value, kwargs):
   1297       if hasattr(self, '_log_prob'):
   1298         return self._log_prob(value, **kwargs)

/opt/anaconda/envs/p39/lib/python3.9/contextlib.py in __enter__(self)
    117         del self.args, self.kwds, self.func
    118         try:
--> 119             return next(self.gen)
    120         except StopIteration:
    121             raise RuntimeError("generator didn't yield") from None

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/distribution.py in _name_and_control_scope(self, name, value, kwargs)
   1995         deps.extend(self._parameter_control_dependencies(is_init=False))
   1996         if value is not UNSET_VALUE:
-> 1997           deps.extend(self._sample_control_dependencies(
   1998               value, **({} if kwargs is None else kwargs)))
   1999         if not deps:

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow_probability/python/distributions/beta.py in _sample_control_dependencies(self, x)
    335     if not self.validate_args:
    336       return assertions
--> 337     assertions.append(assert_util.assert_non_negative(
    338         x, message='Sample must be non-negative.'))
    339     assertions.append(assert_util.assert_less_equal(

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/opt/anaconda/envs/p39/lib/python3.9/site-packages/tensorflow/python/ops/check_ops.py in _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize, message, name)
    407         data = [message] + list(data)
    408 
--> 409       raise errors.InvalidArgumentError(
    410           node_def=None,
    411           op=None,

InvalidArgumentError: Sample must be non-negative..  
Condition x >= 0 did not hold element-wise:
x (shape=(100000,) dtype=float32) = 
['-0.44549665', '-0.05629177', '-0.39321715', '...']